{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch_geometric.datasets import AMiner\n",
    "from torch_geometric.nn import MetaPath2Vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MetaPath2Vec\n",
    "\n",
    "[paper](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf)  \n",
    "[code](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/metapath2vec.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the dataset\n",
    "path = osp.join('..', 'data', 'AMiner')\n",
    "dataset = AMiner(path)\n",
    "data = dataset[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(type(data.edge_index_dict))\n",
    "print(data.edge_index_dict[('paper', 'written by', 'author')])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(type(data.num_nodes_dict))\n",
    "print(data.num_nodes_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(type(data.y_dict))\n",
    "print(data.y_dict[\"venue\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(type(data.y_index_dict))\n",
    "print(data.y_index_dict[\"venue\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# move the data to cpu or GPU\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "device = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the model\n",
    "\n",
    "metapath = [\n",
    "    ('author', 'wrote', 'paper'),\n",
    "    ('paper', 'published in', 'venue'),\n",
    "    ('venue', 'published', 'paper'),\n",
    "    ('paper', 'written by', 'author'),\n",
    "]\n",
    "\n",
    "\n",
    "model = MetaPath2Vec(data.edge_index_dict, \n",
    "                     embedding_dim=128,\n",
    "                     metapath=metapath,\n",
    "                     walk_length=5, \n",
    "                     context_size=3,\n",
    "                     walks_per_node=3,\n",
    "                     num_negative_samples=1,\n",
    "                     sparse=True\n",
    "                    ).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the loader to build a loader\n",
    "loader = model.loader(batch_size=128, shuffle=True, num_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx, (pos_rw, neg_rw) in enumerate(loader):\n",
    "    if idx == 10: break\n",
    "    print(idx, pos_rw.shape, neg_rw.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(pos_rw[0],neg_rw[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inizialize optimizer\n",
    "optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch, log_steps=500, eval_steps=1000):\n",
    "    model.train()\n",
    "\n",
    "    total_loss = 0\n",
    "    for i, (pos_rw, neg_rw) in enumerate(loader):\n",
    "        optimizer.zero_grad()\n",
    "        loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        if (i + 1) % log_steps == 0:\n",
    "            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '\n",
    "                   f'Loss: {total_loss / log_steps:.4f}'))\n",
    "            total_loss = 0\n",
    "\n",
    "        if (i + 1) % eval_steps == 0:\n",
    "            acc = test()\n",
    "            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '\n",
    "                   f'Acc: {acc:.4f}'))\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(train_ratio=0.1):\n",
    "    model.eval()\n",
    "\n",
    "    z = model('author', batch=data.y_index_dict['author'])\n",
    "    y = data.y_dict['author']\n",
    "\n",
    "    perm = torch.randperm(z.size(0))\n",
    "    train_perm = perm[:int(z.size(0) * train_ratio)]\n",
    "    test_perm = perm[int(z.size(0) * train_ratio):]\n",
    "\n",
    "    return model.test(z[train_perm], y[train_perm], z[test_perm],\n",
    "                      y[test_perm], max_iter=150)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(1, 2):\n",
    "    train(epoch)\n",
    "    acc = test()\n",
    "    print(f'Epoch: {epoch}, Accuracy: {acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = MetaPath2Vec(data.edge_index_dict, \n",
    "                     embedding_dim=128,\n",
    "                     metapath=metapath,\n",
    "                     walk_length=5, \n",
    "                     context_size=3,\n",
    "                     walks_per_node=3,\n",
    "                     num_negative_samples=1,\n",
    "                     sparse=True\n",
    "                    ).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(loaded_model.embedding.weight[1][:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the model\n",
    "loaded_model.load_state_dict(torch.load(\"mymodel\").detach().cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# move the model to cpu\n",
    "file = torch.load('mymodel', map_location=lambda storage, loc: storage)\n",
    "loaded_model.load_state_dict(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(loaded_model.embedding.weight[1][:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_venue = loaded_model('venue', batch=data.y_index_dict['venue']).detach().numpy()\n",
    "z_auth = loaded_model('author', batch=data.y_index_dict['author']).detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_venue = z_venue[0:100]\n",
    "z_auth = z_auth[0:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import umap\n",
    "\n",
    "embedder = umap.UMAP().fit(data,y)\n",
    "\n",
    "z_venue_2d = umap.UMAP().fit_transform(z_venue)\n",
    "z_auth_2d = umap.UMAP().fit_transform(z_auth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "plt.figure(figsize=(6,6))\n",
    "plt.scatter(z_auth_2d[:,0],z_auth_2d[:,1],color=\"red\",alpha=0.5,label=\"author\")\n",
    "plt.scatter(z_venue_2d[:,0],z_venue_2d[:,1],color=\"blue\",alpha=0.5,label=\"venue\")\n",
    "plt.legend()\n",
    "plt.title(\"2D embedding\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
